# -*- coding: utf-8 -*-
"""
Created on Wed Aug 31 09:56:10 2022

@author: wulong
"""
import numpy as np
from casadi import *
from Day_model import* 

# In[1] Day-ahead Optimization
def formulate_opt_h(initial_state, state_guess, state_bnds_lo, state_bnds_up, 
                    input_cont_guess, input_cont_bnds_lo, input_cont_bnds_up, 
                    input_cont_lbg, input_cont_ubg, 
                    input_bin_guess, input_bin_bnds_lo, input_bin_bnds_up, 
                    distb_bnds,
                    output_guess, 
                    y2_setpnts_lo, y2_setpnts_up, 
                    output_bnds_lo, output_bnds_up,
                    alpha, alpha_slack, pred_horzn,
                    pf, pmg, pse, 
                    pcm, ppn, rxi, rxi_as):
    """Build the day-ahead optimization problem over ``pred_horzn`` steps.

    For every step k = 1..pred_horzn the following decision variables are
    created, in this order: continuous inputs ``U{k-1}`` (size Nuc_h),
    discrete inputs ``Z{k-1}`` (size Nuz_h), states ``X{k}`` (size Nx_h),
    outputs ``y{k}`` (size Ny_h), two slacks ``e{k}`` and one baseline
    variable ``yeb{k}``.  Callers that unpack the solution vector rely on
    this ordering.

    Parameters
    ----------
    initial_state
        State passed as ``x0`` to the first model step.
    state_guess, input_cont_guess, input_bin_guess, output_guess
        Initial guesses; row ``k-1`` is used for step k.
    state_bnds_lo, state_bnds_up
        Variable bounds applied to every state vector.
    input_cont_bnds_lo, input_cont_bnds_up
        Variable bounds applied to every continuous input vector.
    input_cont_lbg, input_cont_ubg
        Limits multiplied element-wise by the switching vector built from
        the discrete inputs, giving ``Zu*lbg <= U <= Zu*ubg``.
    input_bin_bnds_lo, input_bin_bnds_up
        Variable bounds applied to every discrete input vector.
    distb_bnds
        Row ``k-1`` is used as the disturbance vector of step k; its
        element 2 enters the cost weighted by ``pmg``.
    y2_setpnts_lo, y2_setpnts_up
        Per-step bounds on ``y[1] + e[1]``.
    output_bnds_lo, output_bnds_up
        Variable bounds applied to every output vector.
    alpha
        Weight on the economic part of the cost.
    alpha_slack
        Two weights on the squared slack variables.
    pred_horzn
        Number of steps in the horizon.
    pf, pmg
        Scalar cost coefficients on ``U[0] + U[1]`` and on ``D[2]``.
    pse, pcm, ppn, rxi, rxi_as
        Per-step sell price, compensation price, penalty price and
        regulation factors (indexed with ``k-1``).

    Returns
    -------
    tuple
        ``(w, lbw, ubw, w0, g, lbg, ubg, discrete, J)``: stacked decision
        variables with their bounds and initial guess, stacked constraints
        with their bounds, the per-element discrete flags (list of bool)
        and the accumulated cost expression.
    """
    var_syms, var_lo, var_hi, var_init = [], [], [], []
    con_exprs, con_lo, con_hi = [], [], []
    discrete = []
    J = 0

    def new_variable(name, size, lower, upper, guess, is_integer):
        # Create a symbol and record bounds, guess and integrality together
        # so the five parallel lists can never get out of step.
        sym = MX.sym(name, size)
        var_syms.append(sym)
        var_lo.append(lower)
        var_hi.append(upper)
        var_init.append(guess)
        discrete.extend([is_integer] * size)
        return sym

    def new_constraint(expr, lower, upper):
        # Record a constraint expression together with its bounds.
        con_exprs.append(expr)
        con_lo.append(lower)
        con_hi.append(upper)

    # The first model step starts from the given initial state; later
    # steps start from the state variable created in the previous step.
    x_prev = initial_state

    for idx in range(pred_horzn):
        step = idx + 1

        # Continuous and discrete inputs applied during this step
        Uk = new_variable(f'U{idx}', Nuc_h,
                          input_cont_bnds_lo, input_cont_bnds_up,
                          input_cont_guess[idx, :], False)
        Zk = new_variable(f'Z{idx}', Nuz_h,
                          input_bin_bnds_lo, input_bin_bnds_up,
                          input_bin_guess[idx, :], True)

        # Input constraints: the first three continuous inputs are gated
        # by the discrete inputs, the last two are always enabled.
        switch = vertcat(Zk[0], Zk[1], Zk[2], 1, 1)
        new_constraint(Uk - switch * input_cont_lbg,
                       [0] * Nuc_h, [np.inf] * Nuc_h)
        new_constraint(switch * input_cont_ubg - Uk,
                       [0] * Nuc_h, [np.inf] * Nuc_h)

        # Disturbance for this step
        Dk = distb_bnds[idx, :]

        # Simulate the model over one step
        x_pred = I_ode_h(x0=x_prev, p=vertcat(Uk, Zk, Dk))['xf']

        # New state variable, tied to the prediction by an equality
        Xk = new_variable(f'X{step}', Nx_h,
                          state_bnds_lo, state_bnds_up,
                          state_guess[idx, :], False)
        new_constraint(x_pred - Xk, [0] * Nx_h, [0] * Nx_h)

        # Output variable, tied to the output map by an equality
        Yk = new_variable(f'y{step}', Ny_h,
                          output_bnds_lo, output_bnds_up,
                          output_guess[idx, :], False)
        new_constraint(Yk - out_ies_h(Xk, Uk, Zk, Dk),
                       [0] * Ny_h, [0] * Ny_h)

        # Unbounded slack variables for the two target constraints
        ek = new_variable(f'e{step}', 2,
                          [-np.inf] * 2, [np.inf] * 2, [0] * 2, False)

        # Baseline power variable yeb for y1
        yebk = new_variable(f'yeb{step}', 1, [8], [80], [40], False)

        # Sell/Compensation/Penalty Price, Regulation factor
        pse_k = pse[idx]
        pcm_k = pcm[idx]
        ppn_k = ppn[idx]
        rxi_k = rxi[idx]
        rxi_as_k = rxi_as[idx]

        # Target constraints: y1 tracks the scaled baseline up to slack
        # e[0]; y2 plus slack e[1] must stay inside its set-point band.
        new_constraint(Yk[0] - (1 + rxi_k)*yebk + ek[0], [0], [0])
        new_constraint(Yk[1] + ek[1], y2_setpnts_lo[idx], y2_setpnts_up[idx])

        # Cost for this step: weighted economic term plus slack penalties
        J += alpha * (pf*(Uk[0] + Uk[1]) + ppn_k*(Yk[0] - (1 + rxi_k)*yebk)**2
                      - pmg*Dk[2] - pse_k*Yk[0] - pcm_k*rxi_as_k*yebk)
        J += alpha_slack[0] * ek[0]**2
        J += alpha_slack[1] * ek[1]**2

        # The next step is simulated from this step's state variable
        x_prev = Xk

    # Concatenate decision variables and constraint terms
    return (vertcat(*var_syms), vertcat(*var_lo), vertcat(*var_hi),
            vertcat(*var_init),
            vertcat(*con_exprs), vertcat(*con_lo), vertcat(*con_hi),
            discrete, J)

# In[2] Create Day-ahead Optimization solver
def solve_opt_h(w, lbw, ubw, w0, g, lbg, ubg, discrete, J):
    """Solve the day-ahead problem with Bonmin and return the optimum.

    Parameters
    ----------
    w, lbw, ubw, w0
        Decision variables, their lower/upper bounds and the initial guess.
    g, lbg, ubg
        Constraint expressions and their lower/upper bounds.
    discrete
        Per-element flags handed to the solver as its ``discrete`` option.
    J
        Objective expression.

    Returns
    -------
    The solver's ``x`` result converted with ``.full().ravel()``
    (presumably a flat NumPy array -- confirm against the CasADi version
    in use).  The solver statistics are printed but not checked here.
    """
    print('Creat an Day-ahead solver')
    problem = dict(f=J, x=w, g=g)
    solver_options = dict(discrete=discrete)
    solver = nlpsol('nlp_solver', 'bonmin', problem, solver_options)

    print('Solve Day-ahead')
    solution = solver(x0=w0, lbx=lbw, ubx=ubw, lbg=lbg, ubg=ubg)
    print(solver.stats())

    return solution['x'].full().ravel()